{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocessing configuration: fixed sequence length, batch size, vocabulary\n",
    "# cap, and the location of the cached TFRecord file.\n",
    "MAX_LEN = 300\n",
    "BATCH_SIZE = 32\n",
    "VOCAB_SIZE = 20000\n",
    "TF_RECORD_PATH = './imdb_train_fixed300.tfrecord'\n",
    "\n",
    "# Only load and pad the IMDB training split when the TFRecord has not been\n",
    "# written yet; otherwise X_train / y_train are left undefined on purpose.\n",
    "if not os.path.isfile(TF_RECORD_PATH):\n",
    "    (X_train, y_train), _ = tf.keras.datasets.imdb.load_data(\n",
    "        num_words=VOCAB_SIZE)\n",
    "    # Pad or cut every review at its end so each row holds exactly MAX_LEN ids.\n",
    "    X_train = tf.keras.preprocessing.sequence.pad_sequences(\n",
    "        X_train, maxlen=MAX_LEN, padding='post', truncating='post')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████| 25000/25000 [00:45<00:00, 552.72it/s]\n"
     ]
    }
   ],
   "source": [
    "# Serialize every (padded review, label) pair into the TFRecord file.\n",
    "# Skipped entirely when the file already exists from a previous run.\n",
    "if not os.path.isfile(TF_RECORD_PATH):\n",
    "    # Write to a temporary path first so an interrupted run never leaves a\n",
    "    # truncated file at TF_RECORD_PATH, which the existence check would then\n",
    "    # mistake for a complete dataset on the next run.\n",
    "    tmp_path = TF_RECORD_PATH + '.tmp'\n",
    "    # The context manager closes the writer even if an iteration raises.\n",
    "    with tf.python_io.TFRecordWriter(tmp_path) as writer:\n",
    "        for sent, label in tqdm(zip(X_train, y_train), total=len(X_train), ncols=70):\n",
    "            example = tf.train.Example(\n",
    "                features=tf.train.Features(\n",
    "                    feature={\n",
    "                        # 'sent': the token ids of one padded review\n",
    "                        'sent': tf.train.Feature(\n",
    "                            int64_list=tf.train.Int64List(value=sent)),\n",
    "                        # 'label': a single integer label\n",
    "                        'label': tf.train.Feature(\n",
    "                            int64_list=tf.train.Int64List(value=[label])),\n",
    "                    }))\n",
    "            writer.write(example.SerializeToString())\n",
    "    # Publish the finished file in one step once the writer is closed.\n",
    "    os.replace(tmp_path, TF_RECORD_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, 300) (?,)\n"
     ]
    }
   ],
   "source": [
    "def _parse_fn(example_proto):\n",
    "    \"\"\"Decode one serialized tf.train.Example into a (sent, label) pair.\n",
    "\n",
    "    Args:\n",
    "        example_proto: a single serialized Example record.\n",
    "\n",
    "    Returns:\n",
    "        Tuple of the int64 'sent' tensor of shape [MAX_LEN] and the scalar\n",
    "        int64 'label' tensor.\n",
    "    \"\"\"\n",
    "    feature_spec = {\n",
    "        'sent': tf.FixedLenFeature([MAX_LEN], tf.int64),\n",
    "        'label': tf.FixedLenFeature([], tf.int64),\n",
    "    }\n",
    "    parsed = tf.parse_single_example(example_proto, features=feature_spec)\n",
    "    return parsed['sent'], parsed['label']\n",
    "\n",
    "# Input pipeline: read the TFRecord file, decode each record, then batch.\n",
    "dataset = (tf.data.TFRecordDataset([TF_RECORD_PATH])\n",
    "           .map(_parse_fn)\n",
    "           .batch(BATCH_SIZE))\n",
    "iterator = dataset.make_one_shot_iterator()\n",
    "X_batch, y_batch = iterator.get_next()\n",
    "# Static shapes only: the batch dimension is not known until run time.\n",
    "print(X_batch.shape, y_batch.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(32, 300) (32,)\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the iterator once to fetch the first batch and check its shape.\n",
    "# The session is left open so it can be reused interactively.\n",
    "sess = tf.Session()\n",
    "fetches = [X_batch, y_batch]\n",
    "x, y = sess.run(fetches)\n",
    "print(x.shape, y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
